/*
 Copyright (c) 2024 HigginsSoft, Alexander Higgins - https://github.com/alexhiggins732/ 

 Copyright (c) 2018, Brock Allen & Dominick Baier. All rights reserved.

 Licensed under the Apache License, Version 2.0. See LICENSE in the project root for license information. 
 Source code and license for this software can be found at https://github.com/alexhiggins732/

 The above copyright notice and this permission notice shall be included in all
 copies or substantial portions of the Software.
*/

#pragma warning disable 1591


namespace IdentityServer8.Extensions;

/// <summary>
/// Extension methods on <see cref="HttpContext"/> for reading and writing the
/// per-request IdentityServer state (origin, base path, issuer, sign-out markers).
/// </summary>
public static class HttpContextExtensions
{
    /// <summary>
    /// Determines whether the authentication handler resolved for <paramref name="scheme"/>
    /// implements <see cref="IAuthenticationSignOutHandler"/>.
    /// </summary>
    /// <param name="context">The context.</param>
    /// <param name="scheme">The authentication scheme name to look up.</param>
    /// <returns><c>true</c> if the resolved handler supports sign-out; otherwise <c>false</c>.</returns>
    /// <exception cref="System.ArgumentNullException">context</exception>
    public static async Task<bool> GetSchemeSupportsSignOutAsync(this HttpContext context, string scheme)
    {
        if (context == null) throw new ArgumentNullException(nameof(context));

        var provider = context.RequestServices.GetRequiredService<IAuthenticationHandlerProvider>();
        var handler = await provider.GetHandlerAsync(context, scheme);
        return handler is IAuthenticationSignOutHandler;
    }

    /// <summary>
    /// Overwrites the scheme and host of the current request from an origin string
    /// of the form <c>scheme://host</c>.
    /// </summary>
    /// <param name="context">The context.</param>
    /// <param name="value">The origin, e.g. <c>https://login.example.com</c>.</param>
    /// <exception cref="System.ArgumentNullException">context or value</exception>
    public static void SetIdentityServerOrigin(this HttpContext context, string value)
    {
        if (context == null) throw new ArgumentNullException(nameof(context));
        if (value == null) throw new ArgumentNullException(nameof(value));

        // NOTE: when value contains no "://" separator the single resulting segment is used
        // for both scheme and host; an empty value makes First() throw InvalidOperationException.
        var split = value.Split(new[] { "://" }, StringSplitOptions.RemoveEmptyEntries);

        var request = context.Request;
        request.Scheme = split.First();
        request.Host = new HostString(split.Last());
    }

    /// <summary>
    /// Stores the IdentityServer base path in the items of the current request.
    /// </summary>
    /// <param name="context">The context.</param>
    /// <param name="value">The base path to store.</param>
    /// <exception cref="System.ArgumentNullException">context</exception>
    public static void SetIdentityServerBasePath(this HttpContext context, string value)
    {
        if (context == null) throw new ArgumentNullException(nameof(context));

        context.Items[Constants.EnvironmentKeys.IdentityServerBasePath] = value;
    }

    /// <summary>
    /// Gets the origin (<c>scheme://host</c>) of IdentityServer for the current request.
    /// When mutual TLS is enabled with a sub-domain style domain name (one without a dot)
    /// and the request host starts with that sub-domain, the sub-domain is stripped.
    /// </summary>
    /// <param name="context">The context.</param>
    /// <returns>The origin of the current request.</returns>
    /// <exception cref="System.ArgumentNullException">context</exception>
    public static string GetIdentityServerOrigin(this HttpContext context)
    {
        if (context == null) throw new ArgumentNullException(nameof(context));

        var options = context.RequestServices.GetRequiredService<IdentityServerOptions>();
        var request = context.Request;
        var host = request.Host.Value;
        var mtlsDomain = options.MutualTls.DomainName;

        if (options.MutualTls.Enabled && mtlsDomain.IsPresent() && !mtlsDomain.Contains("."))
        {
            // Only strip when the host really is "<subdomain>.<rest>". Matching on the bare
            // name would throw for a host equal to the sub-domain and would mangle hosts
            // that merely share the prefix (e.g. "mtlsfoo.example.com").
            var prefix = mtlsDomain + ".";
            if (host != null &&
                host.Length > prefix.Length &&
                host.StartsWith(prefix, StringComparison.OrdinalIgnoreCase))
            {
                return request.Scheme + "://" + host.Substring(prefix.Length);
            }
        }

        return request.Scheme + "://" + host;
    }


    /// <summary>
    /// Marks the current request as having triggered a sign-out.
    /// </summary>
    /// <param name="context">The context.</param>
    /// <exception cref="System.ArgumentNullException">context</exception>
    internal static void SetSignOutCalled(this HttpContext context)
    {
        if (context == null) throw new ArgumentNullException(nameof(context));

        context.Items[Constants.EnvironmentKeys.SignOutCalled] = "true";
    }

    /// <summary>
    /// Checks whether the sign-out marker has been set on the current request.
    /// </summary>
    /// <param name="context">The context.</param>
    /// <returns><c>true</c> if the marker key is present in the request items.</returns>
    /// <exception cref="System.ArgumentNullException">context</exception>
    internal static bool GetSignOutCalled(this HttpContext context)
    {
        if (context == null) throw new ArgumentNullException(nameof(context));

        // Only the presence of the key matters; the stored value is not inspected.
        return context.Items.ContainsKey(Constants.EnvironmentKeys.SignOutCalled);
    }

    /// <summary>
    /// Gets the host name of IdentityServer.
    /// </summary>
    /// <param name="context">The context.</param>
    /// <returns>The request scheme and URI-encoded host in the form <c>scheme://host</c>.</returns>
    /// <exception cref="System.ArgumentNullException">context</exception>
    public static string GetIdentityServerHost(this HttpContext context)
    {
        if (context == null) throw new ArgumentNullException(nameof(context));

        var request = context.Request;
        return request.Scheme + "://" + request.Host.ToUriComponent();
    }

    /// <summary>
    /// Gets the base path of IdentityServer.
    /// </summary>
    /// <param name="context">The context.</param>
    /// <returns>The base path stored in the request items, or <c>null</c> if the stored value is not a string.</returns>
    /// <exception cref="System.ArgumentNullException">context</exception>
    public static string GetIdentityServerBasePath(this HttpContext context)
    {
        if (context == null) throw new ArgumentNullException(nameof(context));

        return context.Items[Constants.EnvironmentKeys.IdentityServerBasePath] as string;
    }

    /// <summary>
    /// Gets the public base URL for IdentityServer.
    /// </summary>
    /// <param name="context">The context.</param>
    /// <returns>The IdentityServer host followed by the base path.</returns>
    /// <exception cref="System.ArgumentNullException">context</exception>
    public static string GetIdentityServerBaseUrl(this HttpContext context)
    {
        if (context == null) throw new ArgumentNullException(nameof(context));

        return context.GetIdentityServerHost() + context.GetIdentityServerBasePath();
    }

    /// <summary>
    /// Gets the identity server relative URL.
    /// </summary>
    /// <param name="context">The context.</param>
    /// <param name="path">The path.</param>
    /// <returns>
    /// The absolute URL built from the IdentityServer base URL and <paramref name="path"/>,
    /// or <c>null</c> when <paramref name="path"/> is not a local URL.
    /// </returns>
    /// <exception cref="System.ArgumentNullException">context (only reached for local paths)</exception>
    public static string GetIdentityServerRelativeUrl(this HttpContext context, string path)
    {
        if (!path.IsLocalUrl())
        {
            return null;
        }

        if (context == null) throw new ArgumentNullException(nameof(context));

        // Drop the "~" of an application-relative path; ordinal comparison because this is URL text.
        if (path.StartsWith("~/", StringComparison.Ordinal)) path = path.Substring(1);

        return context.GetIdentityServerBaseUrl().EnsureTrailingSlash() + path.RemoveLeadingSlash();
    }

    /// <summary>
    /// Gets the identity server issuer URI.
    /// </summary>
    /// <param name="context">The context.</param>
    /// <returns>
    /// The configured issuer URI if one is set; otherwise the origin plus base path of the
    /// current request, without a trailing slash and optionally lower-cased.
    /// </returns>
    /// <exception cref="System.ArgumentNullException">context</exception>
    public static string GetIdentityServerIssuerUri(this HttpContext context)
    {
        if (context == null) throw new ArgumentNullException(nameof(context));

        // if they've explicitly configured a URI then use it,
        // otherwise dynamically calculate it
        var options = context.RequestServices.GetRequiredService<IdentityServerOptions>();
        var uri = options.IssuerUri;
        if (uri.IsMissing())
        {
            uri = context.GetIdentityServerOrigin() + context.GetIdentityServerBasePath();

            // ordinal comparison: this is URL text, not linguistic text
            if (uri.EndsWith("/", StringComparison.Ordinal)) uri = uri.Substring(0, uri.Length - 1);

            if (options.LowerCaseIssuerUri)
            {
                uri = uri.ToLowerInvariant();
            }
        }

        return uri;
    }

    /// <summary>
    /// Builds the URL of the end-session callback iframe, persisting a logout notification
    /// message for the clients that need to be signed out.
    /// </summary>
    /// <param name="context">The context.</param>
    /// <param name="logoutMessage">
    /// Optional logout message; when it carries client ids it takes precedence over the current user's session.
    /// </param>
    /// <returns>The callback URL, or <c>null</c> when there are no clients to notify.</returns>
    /// <exception cref="System.ArgumentNullException">context</exception>
    internal static async Task<string> GetIdentityServerSignoutFrameCallbackUrlAsync(this HttpContext context, LogoutMessage logoutMessage = null)
    {
        if (context == null) throw new ArgumentNullException(nameof(context));

        var userSession = context.RequestServices.GetRequiredService<IUserSession>();
        var user = await userSession.GetUserAsync();
        var currentSubId = user?.GetSubjectId();

        LogoutNotificationContext endSessionMsg = null;

        // if we have a logout message, then that takes precedence over the current user
        if (logoutMessage?.ClientIds?.Any() == true)
        {
            // logoutMessage and its ClientIds are known to be non-null from here on
            var clientIds = logoutMessage.ClientIds;

            // check if current user is same, since we might have new clients (albeit unlikely)
            if (currentSubId == logoutMessage.SubjectId)
            {
                // Union already yields each client id once, so no extra Distinct is needed
                clientIds = clientIds.Union(await userSession.GetClientListAsync());
            }

            endSessionMsg = new LogoutNotificationContext
            {
                SubjectId = logoutMessage.SubjectId,
                SessionId = logoutMessage.SessionId,
                ClientIds = clientIds
            };
        }
        else if (currentSubId != null)
        {
            // see if current user has any clients they need to signout of
            var clientIds = await userSession.GetClientListAsync();
            if (clientIds.Any())
            {
                endSessionMsg = new LogoutNotificationContext
                {
                    SubjectId = currentSubId,
                    SessionId = await userSession.GetSessionIdAsync(),
                    ClientIds = clientIds
                };
            }
        }

        if (endSessionMsg == null)
        {
            // no sessions, so nothing to cleanup
            return null;
        }

        var clock = context.RequestServices.GetRequiredService<ISystemClock>();
        var msg = new Message<LogoutNotificationContext>(endSessionMsg, clock.UtcNow.UtcDateTime);

        // persist the notification and reference it from the callback URL by the id the store returns
        var endSessionMessageStore = context.RequestServices.GetRequiredService<IMessageStore<LogoutNotificationContext>>();
        var id = await endSessionMessageStore.WriteAsync(msg);

        var signoutIframeUrl = context.GetIdentityServerBaseUrl().EnsureTrailingSlash() + Constants.ProtocolRoutePaths.EndSessionCallback;
        signoutIframeUrl = signoutIframeUrl.AddQueryString(Constants.UIConstants.DefaultRoutePathParams.EndSessionCallback, id);

        return signoutIframeUrl;
    }
}
